package com.meida.common.dbinit;


import cn.hutool.core.util.ClassUtil;
import com.alibaba.fastjson.JSON;
import com.baomidou.mybatisplus.annotation.TableName;
import com.meida.common.exception.OpenException;
import com.meida.common.utils.FlymeUtils;
import com.meida.common.utils.ToolUtil;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.jdbc.ScriptRunner;
import org.apache.ibatis.jdbc.SqlRunner;
import org.springframework.beans.factory.annotation.Autowired;

import javax.sql.DataSource;
import java.io.Reader;
import java.lang.reflect.Field;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;


/**
 * @author: zyf
 * @date: 2018/12/17 15:16
 * @desc: 数据库初始化，可初始化表，校验字段，校验表名是否存在等
 */
@Slf4j
@Getter
@Setter
public abstract class DbInitializer{

    /**
     * 如果为true，则数据库校验失败会抛出异常
     */
    private Boolean fieldValidatorExceptionFlag = true;
    private String tableName;

    @Autowired
    private DataSource dataSource;

    private SqlRunner sqlRunner;

    private ScriptRunner runner;


    private List<String> tables;

    public DbInitializer() {
    }

    public DbInitializer(Boolean fieldValidatorExceptionFlag) {
        this.fieldValidatorExceptionFlag = fieldValidatorExceptionFlag;
    }

    /**
     * 初始化数据库
     */
    public void dbInit(SqlRunner sqlRunner, ScriptRunner runner, List<String> tables) {
        this.sqlRunner = sqlRunner;
        this.runner = runner;
        /**
         * 初始化表
         */
        initTable(tables);

        /**
         * 校验实体和对应表结构是否有不一致的
         */
        fieldsValidate();
    }

    /**
     * 初始化表结构
     */
    private void initTable(List<String> tableLists) {

        //校验参数
        tableName = this.getTableName();
        try {
            //判断数据库中是否有这张表，如果没有就初始化
            if (!tableLists.contains(tableName.toUpperCase()) && !tableLists.contains(tableName.toLowerCase())) {
                Reader reader = FlymeUtils.getResource("sql/" + getTableName() + ".sql");
                if (FlymeUtils.isNotEmpty(reader)) {
                    runner.runScript(reader);
                }
                log.info("初始化" + tableName + "成功！");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 校验实体和对应表结构是否有不一致的
     */
    private void fieldsValidate() {

        //校验参数
        String sql = this.showColumnsSql();
        if (ToolUtil.isOneEmpty(sql)) {
            if (fieldValidatorExceptionFlag) {
                throw new OpenException("初始化数据库，存在为空的字段");
            }
        }

        //检查数据库中的字段，是否和实体字段一致
        try {
            List<Map<String, Object>> tableFields = sqlRunner.selectAll(sql);
            List<String> tableFieldList = new ArrayList<>();
            if (tableFields != null && !tableFields.isEmpty()) {

                //用于保存实体中不存在的字段的名称集合
                List<String> fieldsNotInClass = new ArrayList<>();

                //用于保存数据表中不存在的字段的名称集合
                List<String> classNotInFields = new ArrayList<>();

                //反射获取字段的所有字段名称
                List<String> classFields = this.getClassFields();

                for (Map<String, Object> tableField : tableFields) {
                    String fieldName = (String) tableField.get("FIELD");
                    if (!classFields.contains(fieldName)) {
                        fieldsNotInClass.add(fieldName);
                    }
                    tableFieldList.add(fieldName);
                }
                for (String classField : classFields) {
                    if (!tableFieldList.contains(classField) && !classField.equals("serialVersionUID")) {
                        classNotInFields.add(classField);
                    }
                }
                if (!classNotInFields.isEmpty()) {
                    log.error("数据表" + tableName + "缺失字段如下：" + JSON.toJSONString(classNotInFields));
                }

                //如果集合不为空，代表有实体和数据库不一致的数据
                if (!fieldsNotInClass.isEmpty()) {
                    log.error("实体" + getEntityClass().getSimpleName() + "缺失字段如下：" + JSON.toJSONString(fieldsNotInClass));
                    if (fieldValidatorExceptionFlag) {
                        System.exit(-1);
                    }
                }
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }

    }

    /**
     * 反射获取类的所有字段
     */
    private List<String> getClassFields() {
        Class<?> entityClass = this.getEntityClass();
        Field[] declaredFields = ClassUtil.getDeclaredFields(entityClass);
        Field[] parentFields = entityClass.getFields();
        ArrayList<String> filedNamesUnderlineCase = new ArrayList<>();
        for (Field declaredField : declaredFields) {
            String fieldName = declaredField.getName();
            filedNamesUnderlineCase.add(fieldName);
        }
        for (Field declaredField : parentFields) {
            String fieldName = declaredField.getName();
            filedNamesUnderlineCase.add(fieldName);
        }
        return filedNamesUnderlineCase;
    }

    /**
     * 获取表的字段
     */
    private String showColumnsSql() {
        return "SHOW COLUMNS FROM " + tableName;
    }



    /**
     * 获取表的名称
     */
    protected String getTableName() {
        TableName tableName = getEntityClass().getAnnotation(TableName.class);
        return tableName.value();
    }


    /**
     * 获取表对应的实体
     */
    protected abstract Class<?> getEntityClass();
}
